# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate target CFG for SQL given Spider databases.

Note that for simplicity and because it has minimal impact on accuracy, the
grammar generated by this file is slightly different than the one used for the
experiments in the paper, which was specialized for each database.
"""

import json

from absl import app
from absl import flags

from language.compgen.nqg.model.parser.inference.targets import target_grammar

from tensorflow.io import gfile


# Command-line flags for this script. Both default to the empty string;
# main() reads the schema from `spider_tables` and writes rules to `output`.
FLAGS = flags.FLAGS

# Path of the JSON file that main() loads the database schemas from.
flags.DEFINE_string("spider_tables", "", "Tables JSON file for Spider.")

# Path of the text file that the generated grammar rules are written to.
flags.DEFINE_string("output", "", "Output rules txt file.")


# Schema-independent grammar rules, one "LHS => RHS" string per rule. Each
# string is parsed with `target_grammar.TargetCfgRule.from_string` in main().
# Tokens prefixed with `##` appear to reference nonterminals; see
# `target_grammar` for the exact rule format (TODO confirm).
# NOTE(review): `ANYTHING` has no productions in this list; presumably it is
# handled specially by the grammar implementation -- confirm.
# Every rule is listed exactly once so that the output grammar contains no
# duplicate productions.
RULES = [
    # Top-level queries and set operations.
    "ROOT => ##ROOT union ##ROOT",
    "ROOT => ##ROOT intersect ##ROOT",
    "ROOT => ##ROOT except ##ROOT",
    "ROOT => select ##EXPR",
    # Select bodies and trailing clauses.
    "EXPR => ##T from ##T",
    "EXPR => ##T from ##FROM",
    "EXPR => ##EXPR where ##ANYTHING",
    "EXPR => ##EXPR group by ##ANYTHING",
    "EXPR => ##EXPR order by ##ANYTHING",
    "EXPR => ##EXPR limit ##ANYTHING",
    # Terms: sub-queries, arithmetic, lists, aliases and qualified names.
    "T => ( ##ROOT )",
    "T => ##T - ##T",
    "T => ##T + ##T",
    "T => ##T / ##T",
    "T => ##T * ##T",
    "T => ( ##T )",
    "T => distinct ##T",
    "T => distinct ( ##T )",
    "T => ##T , ##T",
    "T => *",
    "T => ##T as ##T",
    "T => t1",
    "T => t2",
    "T => t3",
    "T => t4",
    "T => t5",
    "T => t6",
    "T => t7",
    "T => t8",
    "T => t9",
    "T => ##T . ##T",
    # Aggregate functions.
    "T => count ( ##T )",
    "T => sum ( ##T )",
    "T => avg ( ##T )",
    "T => max ( ##T )",
    "T => min ( ##T )",
    # Join clauses.
    "FROM => ##FROM join ##T",
    "FROM => ##T join ##T",
    "FROM => ##FROM on ##JOIN_COL",
    "FROM => ##FROM and ##JOIN_COL",
    "JOIN_COL => ##T = ##T",
]


def load_json(filepath):
  """Reads the file at `filepath` via gfile and returns its parsed JSON."""
  with gfile.GFile(filepath, "r") as json_file:
    return json.loads(json_file.read())


def main(unused_argv):
  """Writes the static rules plus one `T => <name>` rule per schema name.

  Loads the JSON file given by `--spider_tables`, collects every column name
  (second item of each `column_names_original` entry) and every table name
  (`table_names_original`) across all entries, lowercases them, and writes
  the resulting rules to the file given by `--output`.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  tables_json = load_json(FLAGS.spider_tables)

  # Lowercase *before* de-duplicating so that names differing only by case
  # collapse to a single rule instead of producing duplicates.
  schema_elements = set()
  for table in tables_json:
    schema_elements.update(
        name.lower() for _, name in table["column_names_original"])
    schema_elements.update(
        name.lower() for name in table["table_names_original"])

  rule_strings = list(RULES)
  static_rules = set(RULES)
  # Iterate in sorted order: set iteration order for strings can differ
  # between interpreter runs, which would make the output file unstable.
  for name in sorted(schema_elements):
    rule_string = "T => %s" % name
    # Skip schema names that would merely repeat a static rule (for example
    # a `*` column entry, if present, would recreate "T => *").
    if rule_string not in static_rules:
      rule_strings.append(rule_string)

  rules = [
      target_grammar.TargetCfgRule.from_string(rule_string)
      for rule_string in rule_strings
  ]
  target_grammar.rules_to_txt_file(rules, FLAGS.output)


if __name__ == "__main__":
  # Script entry point: hand control to absl, which invokes main().
  app.run(main)
